package vpcserver

import (
	"context"
	"crypto/tls"
	"encoding/json"
	"errors"
	"fmt"
	"net"
	"strconv"
	"strings"
	"sync"
	"time"
)

// VpcLogger, when set, receives every diagnostic message produced via
// PrintVpcLog. A nil VpcLogger silences logging.
var VpcLogger func(msg string)

// HeartbeatInterval is the period at which KeepAlive frames are sent and the
// tunnel heartbeat is checked.
const HeartbeatInterval = 7 * time.Second

// VpcClient owns the tunnels opened on behalf of one proxy.
type VpcClient struct {
	// ProxyId identifies this agent to the VPC server; it is sent in the
	// login frame and copied into each tunnel for heartbeat frames.
	ProxyId string
	// TunnelMap caches live tunnels (*TunnelInfo values), keyed by
	// "<ticket prefix before '@'>@<server address>".
	TunnelMap sync.Map
	// Lock is the client-wide lock; CreateTunnel holds it for its whole call.
	Lock sync.Mutex
}

// TunnelInfo is the state of a single tunnel connection to a VPC server.
type TunnelInfo struct {
	// proxyid is the owning client's ProxyId.
	proxyid string
	// server is the tunnel server address.
	server string
	// ticket is the ticket that was used to log in to server.
	ticket string
	// assetConMap caches *AssetConnWrapper values keyed by the client
	// address the server reports for each session.
	assetConMap sync.Map
	// conn is the tunnel connection itself.
	conn net.Conn
	// ioChan queues encoded frames waiting to be written to conn.
	ioChan chan []byte
	// tcpBufSize is the read-buffer size for asset connections, in KiB
	// (documented default: 128K).
	tcpBufSize int
	// maxIdealTimeOut is the maximum idle time of a target-asset TCP
	// connection, in minutes (documented default: 5 minutes).
	maxIdealTimeOut int
	// lastHeartBeatTime is when the last KeepAlive reply arrived.
	// NOTE(review): written and read from different goroutines without
	// synchronization.
	lastHeartBeatTime time.Time
}

// AssetConnWrapper tracks one proxied session to a target asset.
type AssetConnWrapper struct {
	// AssetConn is the client-side session connection to the asset.
	AssetConn net.Conn
	// AssetAddr is the target asset's address.
	AssetAddr string
	// ClientAddr is the cloud-side client this asset session serves.
	ClientAddr string
	// LastVisitTime is when a frame was last forwarded to the asset.
	LastVisitTime time.Time
}

// Start runs the receive loop of one established tunnel and blocks until the
// tunnel fails or an unexpected frame arrives.
//
// It launches the tunnel's helper goroutines (heartbeat sender, frame writer,
// heartbeat watchdog and idle-asset reaper), all bound to a context that is
// cancelled when Start returns. Frames whose command is VpcAgentResponse are
// decoded as AgentDto:
//   - Data: payLoad.body is forwarded to the asset connection cached under
//     rst.ClientAddr, dialing rst.AssetAddr (30s timeout) on first use. A dial
//     or write failure drops that asset connection and queues a CloseConn frame.
//   - KeepAlive: refreshes lastHeartBeatTime.
//
// Any other command ends the tunnel. On exit the tunnel connection and every
// cached asset connection are closed, and the tunnel is removed from
// vpcClient.TunnelMap.
func (vpcClient *VpcClient) Start(tunInfo *TunnelInfo) {
	ctx, cancel := context.WithCancel(context.Background())
	defer func() {
		if e := recover(); e != nil {
			PrintVpcLog(fmt.Sprintf("Start Panicing %s\n", e))
		}
		cancel()
		tunInfo.conn.Close()
		// Close every asset connection proxied through this tunnel; otherwise
		// their sockets and reader goroutines outlive the tunnel.
		// NOTE(review): the readers then queue CloseConn frames on ioChan, which
		// is no longer drained, so they can still block once its buffer is full.
		tunInfo.assetConMap.Range(func(key, value any) bool {
			if wrapper, ok := value.(*AssetConnWrapper); ok {
				_ = wrapper.AssetConn.Close()
			}
			tunInfo.assetConMap.Delete(key)
			return true
		})
		vpcClient.TunnelMap.Delete(strings.Split(tunInfo.ticket, "@")[0] + "@" + tunInfo.server)
	}()
	go tunInfo.keepAlive(ctx)
	go tunInfo.SendMessages(ctx)
	go tunInfo.clearUnUsedProxy(ctx)
	go tunInfo.clearUnUsedAssetConn(ctx)
	for {
		payLoad, err := receive(tunInfo.conn)
		if err != nil {
			PrintVpcLog("receive data err！ProxyId:" + vpcClient.ProxyId + " LocalAddr:" + tunInfo.conn.LocalAddr().String() + "  " + err.Error())
			return
		}
		if len(payLoad.command) == 0 || payLoad.command[0] != VpcAgentResponse[0] {
			PrintVpcLog("this is an illegal request,close the channel!")
			return
		}
		rst := &AgentDto{}
		if err = json.Unmarshal(payLoad.head, rst); err != nil {
			PrintVpcLog("decode agent response err! ProxyId:" + vpcClient.ProxyId + "  " + err.Error())
			return
		}
		switch rst.Action {
		case Data:
			// rst.ClientAddr identifies the server-side client session of this frame.
			var wrapperConn *AssetConnWrapper
			if value, ok := tunInfo.assetConMap.Load(rst.ClientAddr); ok {
				wrapperConn = value.(*AssetConnWrapper)
			} else {
				// First frame for this session: dial the target asset.
				con, dialErr := net.DialTimeout("tcp", rst.AssetAddr, 30*time.Second)
				if dialErr != nil {
					tunInfo.sendClosed(rst)
					continue
				}
				wrapperConn = &AssetConnWrapper{AssetConn: con, AssetAddr: rst.AssetAddr, ClientAddr: rst.ClientAddr, LastVisitTime: time.Now()}
				// Cache before starting the reader, so a reader that fails
				// immediately removes this entry instead of racing ahead of Store.
				tunInfo.assetConMap.Store(rst.ClientAddr, wrapperConn)
				// Give the reader its own copy: rst may be modified below
				// (sendClosed) while the reader is still encoding it.
				readerDto := *rst
				go tunInfo.handlerAssetConn(con, &readerDto)
			}
			if _, err = wrapperConn.AssetConn.Write(payLoad.body); err != nil {
				// The asset is unusable: drop it and tell the server to close its side.
				_ = wrapperConn.AssetConn.Close()
				tunInfo.assetConMap.Delete(rst.ClientAddr)
				tunInfo.sendClosed(rst)
				continue
			}
			wrapperConn.LastVisitTime = time.Now()
		case KeepAlive:
			tunInfo.lastHeartBeatTime = time.Now()
		}
	}
}

// NewClient returns a VpcClient for proxyId with an empty tunnel cache.
func NewClient(proxyId string) *VpcClient {
	return &VpcClient{ProxyId: proxyId}
}
// PrintVpcLog forwards msg to VpcLogger; it is a no-op when no logger is set.
func PrintVpcLog(msg string) {
	logger := VpcLogger
	if logger == nil {
		return
	}
	logger(msg)
}

// CreateTunnel creates (or reuses) the tunnel for ticket to the first reachable
// server in serverAddr, and serves it in a background goroutine (Start).
//
// serverAddr is a comma-separated list of addresses; the first one that
// loadBalanceServer reports as reachable is used. If a tunnel keyed by
// "<ticket prefix before '@'>@<server>" is already cached in TunnelMap, nil is
// returned without dialing.
//
// Failed logins are retried until retryTime attempts have been made (0 means
// 3); a dial failure is returned immediately. tcpBufSize (KiB) and
// maxIdealTimeOut (minutes) fall back to DefaultTcpBufSize and
// DefaultAssetMaxIdealTimeOut when 0.
//
// The client-wide Lock is held for the whole call. A panic inside the call is
// logged and returned as an error instead of being reported as success.
func (vpcClient *VpcClient) CreateTunnel(serverAddr, ticket, token string, tcpBufSize, maxIdealTimeOut, retryTime int) (err error) {
	vpcClient.Lock.Lock()
	defer func() {
		if e := recover(); e != nil {
			PrintVpcLog(fmt.Sprintf("CreateTunnel Panicing %s\n", e))
			// Without this, the caller would see a nil error after a panic.
			err = fmt.Errorf("create tunnel panicked: %v", e)
		}
		vpcClient.Lock.Unlock()
	}()
	if retryTime == 0 {
		retryTime = 3
	}
	// NOTE(security): server certificates are NOT verified, which exposes the
	// tunnel to man-in-the-middle attacks. Kept for compatibility; consider
	// trusting the server CA explicitly via tls.Config.RootCAs.
	tlsConfig := &tls.Config{InsecureSkipVerify: true}
	ticketPrefix := strings.Split(ticket, "@")[0]
	for attempt := 0; ; attempt++ {
		if attempt >= retryTime {
			return errors.New("the connection has exceeded the maximum number of times: " + strconv.Itoa(attempt))
		}
		// Pick a server; currently the first one that passes a connectivity probe.
		found, server := loadBalanceServer(strings.Split(serverAddr, ","))
		if !found {
			PrintVpcLog(fmt.Sprintf("can't find the active VpcServer:%s", serverAddr))
			return errors.New("can't find the active VpcServer")
		}
		tunnelKey := ticketPrefix + "@" + server
		// Already connected to this server for this ticket: nothing to do.
		if _, ok := vpcClient.TunnelMap.Load(tunnelKey); ok {
			return nil
		}
		dialer := &net.Dialer{KeepAlive: 30 * time.Minute, Timeout: 30 * time.Second}
		conn, dialErr := tls.DialWithDialer(dialer, "tcp", server, tlsConfig)
		if dialErr != nil {
			return dialErr
		}
		if loginErr := vpcClient.login(conn, ticket, token); loginErr != nil {
			PrintVpcLog(fmt.Sprintf("login vpcserver failed serverAddr:%s  proxyId:%s  err:%v", server, vpcClient.ProxyId, loginErr))
			_ = conn.Close()
			continue
		}
		PrintVpcLog(fmt.Sprintf("login vpcserver success serverAddr:%s  proxyId:%s\n", server, vpcClient.ProxyId))
		tunInfo := &TunnelInfo{
			proxyid:           vpcClient.ProxyId,
			conn:              conn,
			ioChan:            make(chan []byte, 100),
			lastHeartBeatTime: time.Now(),
			server:            server,
			ticket:            ticket,
			tcpBufSize:        tcpBufSize,
			maxIdealTimeOut:   maxIdealTimeOut,
		}
		if tunInfo.tcpBufSize == 0 {
			tunInfo.tcpBufSize = DefaultTcpBufSize
		}
		if tunInfo.maxIdealTimeOut == 0 {
			tunInfo.maxIdealTimeOut = DefaultAssetMaxIdealTimeOut
		}
		vpcClient.TunnelMap.Store(tunnelKey, tunInfo)
		go vpcClient.Start(tunInfo)
		return nil
	}
}
// SendMessages is the tunnel's writer: it drains tunInfo.ioChan and writes each
// encoded frame to the tunnel connection until ctx is cancelled.
//
// On a write error the tunnel connection is closed before returning. This makes
// the receive loop in Start fail and tear the whole tunnel down. Without the
// close, producers would keep queueing frames that nobody drains.
func (tunInfo *TunnelInfo) SendMessages(ctx context.Context) {
	defer func() {
		if e := recover(); e != nil {
			PrintVpcLog(fmt.Sprintf("SendMessages Panicing %s\n", e))
		}
	}()
	for {
		select {
		case <-ctx.Done():
			return
		case msg, ok := <-tunInfo.ioChan:
			if !ok {
				// A closed channel would otherwise make this select spin forever.
				return
			}
			if _, err := tunInfo.conn.Write(msg); err != nil {
				PrintVpcLog("send data err! ProxyId:" + tunInfo.proxyid + "  " + err.Error())
				_ = tunInfo.conn.Close()
				return
			}
		}
	}
}
// keepAlive queues a KeepAlive frame (AgentDto with Action KeepAlive and this
// tunnel's proxy id) every HeartbeatInterval until ctx is cancelled.
func (tunInfo *TunnelInfo) keepAlive(ctx context.Context) {
	ticker := time.NewTicker(HeartbeatInterval)
	defer func() {
		if e := recover(); e != nil {
			PrintVpcLog(fmt.Sprintf("keepAlive Panicing %s\n", e))
		}
		ticker.Stop()
	}()
	dto := &AgentDto{
		Action:  KeepAlive,
		ProxyId: tunInfo.proxyid,
	}
	head, err := json.Marshal(dto)
	if err != nil {
		PrintVpcLog("keepAlive marshal err: " + err.Error())
		return
	}
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			// ioChan is drained only by SendMessages. Once that has stopped, a
			// bare send would block forever, so keep watching ctx while enqueueing.
			select {
			case tunInfo.ioChan <- PayLoadEncode(VpcAgentRequest, head, nil):
			case <-ctx.Done():
				return
			}
		}
	}
}
// sendClosed queues a CloseConn frame telling the server that the asset session
// described by rst is gone. It blocks until ioChan has room.
//
// A copy of rst is encoded so the caller's DTO is left untouched. The same DTO
// can be in use by another goroutine (handlerAssetConn encodes it as the head
// of every Data frame it forwards).
func (tunInfo *TunnelInfo) sendClosed(rst *AgentDto) {
	defer func() {
		if e := recover(); e != nil {
			PrintVpcLog(fmt.Sprintf("sendClosed Panicing %s\n", e))
		}
	}()
	closeDto := *rst
	closeDto.Action = CloseConn
	head, err := json.Marshal(&closeDto)
	if err != nil {
		PrintVpcLog("sendClosed marshal err: " + err.Error())
		return
	}
	tunInfo.ioChan <- PayLoadEncode(VpcAgentRequest, head, nil)
}
// clearUnUsedProxy is the heartbeat watchdog. Every HeartbeatInterval it closes
// the tunnel connection if no KeepAlive reply has arrived for more than three
// minutes, which makes the receive loop in Start fail and tear the tunnel down.
// It returns when ctx is cancelled.
//
// NOTE(review): lastHeartBeatTime is written by Start's goroutine without
// synchronization, so this read is a data race under -race.
func (tunInfo *TunnelInfo) clearUnUsedProxy(ctx context.Context) {
	const maxSilence = 3 * time.Minute
	checker := time.NewTicker(HeartbeatInterval)
	defer checker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-checker.C:
		}
		if time.Since(tunInfo.lastHeartBeatTime) <= maxSilence {
			continue
		}
		tunInfo.conn.Close()
	}
}
// clearUnUsedAssetConn runs every 30 seconds until ctx is cancelled. Each run
// closes and evicts cached asset connections whose LastVisitTime is older than
// tunInfo.maxIdealTimeOut minutes.
func (tunInfo *TunnelInfo) clearUnUsedAssetConn(ctx context.Context) {
	defer func() {
		// Like the other tunnel goroutines, never let a panic here crash the process.
		if e := recover(); e != nil {
			PrintVpcLog(fmt.Sprintf("clearUnUsedAssetConn Panicing %s\n", e))
		}
	}()
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			maxIdle := time.Minute * time.Duration(tunInfo.maxIdealTimeOut)
			tunInfo.assetConMap.Range(func(key, value any) bool {
				// Close business entries that have seen no traffic for longer than maxIdle.
				if wrapper, ok := value.(*AssetConnWrapper); ok && time.Since(wrapper.LastVisitTime) > maxIdle {
					_ = wrapper.AssetConn.Close()
					// Delete with the key as stored; no type assertion needed.
					tunInfo.assetConMap.Delete(key)
				}
				return true
			})
		}
	}
}

// handlerAssetConn is meant to run in its own goroutine. It reads the target
// asset's replies from assetConn and queues them on the tunnel as Data frames,
// using rst (JSON-encoded once up front) as the frame head.
//
// When the asset connection ends (EOF or any read error), it queues a CloseConn
// frame, closes assetConn and removes the cache entry keyed by rst.ClientAddr.
func (tunInfo *TunnelInfo) handlerAssetConn(assetConn net.Conn, rst *AgentDto) {
	defer func() {
		if e := recover(); e != nil {
			PrintVpcLog(fmt.Sprintf("handlerAssetConn Panicing %s\n", e))
		}
		_ = assetConn.Close()
		tunInfo.assetConMap.Delete(rst.ClientAddr)
	}()
	// tcpBufSize is expressed in KiB.
	buf := make([]byte, 1024*tunInfo.tcpBufSize)
	head, err := json.Marshal(rst)
	if err != nil {
		PrintVpcLog("handlerAssetConn marshal err: " + err.Error())
		return
	}
	for {
		n, readErr := assetConn.Read(buf)
		// Per the io.Reader contract, bytes returned together with an error
		// (e.g. a final chunk with io.EOF) must be forwarded before the error
		// is handled.
		// NOTE(review): buf is reused by the next Read; this assumes PayLoadEncode
		// copies buf[:n] rather than aliasing it. Confirm.
		if n > 0 {
			tunInfo.ioChan <- PayLoadEncode(VpcAgentRequest, head, buf[:n])
		}
		if readErr != nil {
			tunInfo.sendClosed(rst)
			return
		}
	}
}

// login performs the registration handshake on a freshly dialed tunnel
// connection. It sends a Register AgentDto carrying ticket, token and
// vpcClient.ProxyId, then reads a single reply frame.
//
// It returns nil only when the reply is a VpcAgentResponse frame whose AgentDto
// has Action == Register and Success == true. Every other outcome is an error
// whose message starts with "login has failed": I/O error, timeout, malformed
// reply, unexpected command or action, or a rejected registration. The handshake
// is bounded by a 30-second deadline, which is cleared again before returning.
func (vpcClient *VpcClient) login(remoteConn net.Conn, ticket, token string) error {
	// Matches the dial timeout used by CreateTunnel, which holds the client's
	// global lock while logging in; an unresponsive server must not hang it.
	const loginTimeout = 30 * time.Second
	dto := &AgentDto{
		Action:  Register,
		Token:   token,
		ProxyId: vpcClient.ProxyId,
		Ticket:  ticket,
	}
	head, err := json.Marshal(dto)
	if err != nil {
		return fmt.Errorf("login has failed: %w", err)
	}
	if err = remoteConn.SetDeadline(time.Now().Add(loginTimeout)); err != nil {
		return fmt.Errorf("login has failed: %w", err)
	}
	// Clear the handshake deadline so the tunnel's later reads/writes are unbounded.
	defer func() { _ = remoteConn.SetDeadline(time.Time{}) }()
	if _, err = remoteConn.Write(PayLoadEncode(VpcAgentRequest, head, nil)); err != nil {
		return fmt.Errorf("login has failed: %w", err)
	}
	payload, err := receive(remoteConn)
	if err != nil {
		return fmt.Errorf("login has failed: %w", err)
	}
	if len(payload.command) == 0 || payload.command[0] != VpcAgentResponse[0] {
		return errors.New("login has failed: unexpected response command")
	}
	reply := &AgentDto{}
	if err = json.Unmarshal(payload.head, reply); err != nil {
		return fmt.Errorf("login has failed: %w", err)
	}
	if reply.Action != Register {
		return errors.New("login has failed: unexpected response action")
	}
	if !reply.Success {
		return errors.New("login has failed")
	}
	return nil
}
// loadBalanceServer probes the addresses in order and returns the first one
// that TelNetIp reports as reachable. It returns (false, "") when none is.
func loadBalanceServer(adds []string) (bool, string) {
	for i := range adds {
		if alive, _ := TelNetIp(adds[i]); alive {
			return true, adds[i]
		}
	}
	return false, ""
}
